# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Restrained Amber Minimization of a structure."""

import io
import time
from typing import Collection, Optional, Sequence

from absl import logging
import ml_collections
import numpy as np
from simtk import openmm
from simtk import unit
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure

from alphafold.common import protein
from alphafold.common import residue_constants
from alphafold.model import folding
from alphafold.relax import cleanup
from alphafold.relax import utils


# Units in which this module expresses energies and lengths. Plain-float
# inputs are multiplied by these to form OpenMM quantities, and reported
# values are converted back with `value_in_unit(...)` (kcal/mol, Angstroms).
ENERGY = unit.kilocalories_per_mole
LENGTH = unit.angstroms


def will_restrain(atom: openmm_app.Atom, rset: str) -> bool:
  """Returns True if the atom will be restrained by the given restraint set.

  Args:
    atom: The OpenMM atom to test.
    rset: Name of the restraint set. "non_hydrogen" selects every atom whose
      element is not hydrogen; "c_alpha" selects atoms named "CA".

  Returns:
    Whether `atom` belongs to `rset`. An unrecognised restraint set selects
    no atoms and yields `False` (rather than an implicit `None`), so the
    result always honours the declared `bool` return type.
  """
  if rset == "non_hydrogen":
    return atom.element.name != "hydrogen"
  if rset == "c_alpha":
    return atom.name == "CA"
  return False


def _add_restraints(
    system: openmm.System,
    reference_pdb: openmm_app.PDBFile,
    stiffness: unit.Unit,
    rset: str,
    exclude_residues: Sequence[int]):
  """Adds harmonic position restraints tethering atoms to reference positions.

  Every atom selected by `rset` (see `will_restrain`) that is not part of an
  excluded residue receives the potential 0.5 * k * |r - r0|^2, where r0 is
  that atom's position in `reference_pdb`.

  Args:
    system: The OpenMM system to which the restraint force is added.
    reference_pdb: Structure providing the topology and reference positions.
    stiffness: Spring constant `k` as an OpenMM quantity
      (energy / length**2).
    rset: Restraint set name, either "non_hydrogen" or "c_alpha".
    exclude_residues: Zero-indexed residue indices whose atoms are left
      unrestrained.

  Raises:
    ValueError: If `rset` is not a supported restraint set.
  """
  # Explicit check rather than `assert`: asserts are stripped under
  # `python -O`, which would let an unknown set silently restrain nothing.
  if rset not in ("non_hydrogen", "c_alpha"):
    raise ValueError(f"Unsupported restraint set: {rset!r}")

  force = openmm.CustomExternalForce(
      "0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)")
  force.addGlobalParameter("k", stiffness)
  for p in ["x0", "y0", "z0"]:
    force.addPerParticleParameter(p)

  # Membership is tested once per atom, so use a set rather than a list.
  excluded = set(exclude_residues)
  for i, atom in enumerate(reference_pdb.topology.atoms()):
    if atom.residue.index in excluded:
      continue
    if will_restrain(atom, rset):
      force.addParticle(i, reference_pdb.positions[i])
  logging.info("Restraining %d / %d particles.",
               force.getNumParticles(), system.getNumParticles())
  system.addForce(force)


def _openmm_minimize(
    pdb_str: str,
    max_iterations: int,
    tolerance: unit.Unit,
    stiffness: unit.Unit,
    restraint_set: str,
    exclude_residues: Sequence[int]):
  """Runs a single OpenMM energy minimization of a PDB structure.

  The structure is parameterized with the "amber99sb.xml" force field, bonds
  to hydrogens are constrained, and position restraints are attached via
  `_add_restraints` only when `stiffness` is positive. The minimization runs
  on the OpenMM "CPU" platform.

  Args:
    pdb_str: PDB-format text of the structure to minimize.
    max_iterations: Maximum number of minimizer iterations.
    tolerance: Energy tolerance (OpenMM quantity) for `minimizeEnergy`.
    stiffness: Restraint spring constant (energy / length**2 quantity).
    restraint_set: Which atoms to restrain; see `will_restrain`.
    exclude_residues: Zero-indexed residues left unrestrained.

  Returns:
    A dict with the potential energy before ("einit") and after ("efinal")
    minimization in kcal/mol, the positions before ("posinit") and after
    ("pos") in Angstroms, and the minimized structure as PDB text
    ("min_pdb").
  """
  pdb = openmm_app.PDBFile(io.StringIO(pdb_str))

  system = openmm_app.ForceField("amber99sb.xml").createSystem(
      pdb.topology, constraints=openmm_app.HBonds)
  no_stiffness = 0 * ENERGY / (LENGTH**2)
  if stiffness > no_stiffness:
    _add_restraints(system, pdb, stiffness, restraint_set, exclude_residues)

  simulation = openmm_app.Simulation(
      pdb.topology,
      system,
      openmm.LangevinIntegrator(0, 0.01, 0.0),
      openmm.Platform.getPlatformByName("CPU"))
  simulation.context.setPositions(pdb.positions)

  def snapshot():
    return simulation.context.getState(getEnergy=True, getPositions=True)

  before = snapshot()
  result = {
      "einit": before.getPotentialEnergy().value_in_unit(ENERGY),
      "posinit": before.getPositions(asNumpy=True).value_in_unit(LENGTH),
  }

  simulation.minimizeEnergy(maxIterations=max_iterations, tolerance=tolerance)

  after = snapshot()
  result["efinal"] = after.getPotentialEnergy().value_in_unit(ENERGY)
  result["pos"] = after.getPositions(asNumpy=True).value_in_unit(LENGTH)
  result["min_pdb"] = _get_pdb_string(
      simulation.topology, after.getPositions())
  return result


def _get_pdb_string(topology: openmm_app.Topology, positions: unit.Quantity):
  """Serializes an OpenMM topology and matching positions to PDB text.

  Args:
    topology: The OpenMM topology to write.
    positions: Atom positions corresponding to `topology`.

  Returns:
    The PDB-format text produced by `PDBFile.writeFile`.
  """
  buffer = io.StringIO()
  try:
    openmm_app.PDBFile.writeFile(topology, positions, buffer)
    text = buffer.getvalue()
  finally:
    buffer.close()
  return text


def _check_cleaned_atoms(pdb_cleaned_string: str, pdb_ref_string: str):
  """Checks that no atom positions have been altered by cleaning.

  Residues of the two structures are paired in order. Within each pair,
  every atom of the reference residue is compared with the identically
  named atom(s) of the cleaned residue. Atoms present in only one of the two
  residues are not compared.

  Args:
    pdb_cleaned_string: PDB text produced by the cleaning step.
    pdb_ref_string: PDB text of the structure before cleaning.

  Raises:
    ValueError: If paired residues have different names, or if any shared
      atom's coordinates differ between the two structures.
  """
  cleaned = openmm_app.PDBFile(io.StringIO(pdb_cleaned_string))
  reference = openmm_app.PDBFile(io.StringIO(pdb_ref_string))

  cl_xyz = np.array(cleaned.getPositions().value_in_unit(LENGTH))
  ref_xyz = np.array(reference.getPositions().value_in_unit(LENGTH))

  for ref_res, cl_res in zip(reference.topology.residues(),
                             cleaned.topology.residues()):
    # Explicit check instead of `assert`, which `python -O` would strip.
    if ref_res.name != cl_res.name:
      raise ValueError(f"Residue name mismatch after cleaning: reference "
                       f"residue {ref_res} vs cleaned residue {cl_res}.")
    # Index the cleaned atoms by name once per residue instead of rescanning
    # all of them for every reference atom. Lists keep any duplicate names.
    cl_atoms_by_name = {}
    for cat in cl_res.atoms():
      cl_atoms_by_name.setdefault(cat.name, []).append(cat)
    for rat in ref_res.atoms():
      for cat in cl_atoms_by_name.get(rat.name, ()):
        if not np.array_equal(cl_xyz[cat.index], ref_xyz[rat.index]):
          raise ValueError(f"Coordinates of cleaned atom {cat} do not match "
                           f"coordinates of reference atom {rat}.")


def _check_residues_are_well_defined(prot: protein.Protein):
  """Raises if any residue of `prot` has no atoms at all.

  Args:
    prot: Protein whose `atom_mask` is inspected. Per-residue atom counts are
      obtained by summing the mask over its last axis.

  Raises:
    ValueError: If at least one residue has an all-zero atom mask.
  """
  atoms_per_residue = prot.atom_mask.sum(axis=-1)
  if np.any(atoms_per_residue == 0):
    raise ValueError("Amber minimization can only be performed on proteins with"
                     " well-defined residues. This protein contains at least"
                     " one residue with no atoms.")


def _check_atom_mask_is_ideal(prot):
  """Verifies that `prot`'s atom mask is ideal, up to a possible OXT.

  The actual atom mask is compared against `protein.ideal_atom_mask(prot)`
  using `utils.assert_equal_nonterminal_atom_types`, which is expected to
  raise on mismatch.
  """
  utils.assert_equal_nonterminal_atom_types(
      prot.atom_mask, protein.ideal_atom_mask(prot))


def clean_protein(
    prot: protein.Protein,
    checks: bool = True):
  """Converts a Protein to PDB text and cleans it for Amber minimization.

  The protein is written to PDB, passed through `cleanup.fix_pdb` and then
  `cleanup.clean_structure`. Both steps record their changes in a shared
  alterations dict, which is logged.

  Args:
    prot: A `protein.Protein` instance. Its atom mask is first checked to be
      ideal (up to a possible OXT).
    checks: If True, verify afterwards that cleaning did not move any atom
      that was present in the original structure.

  Returns:
    The cleaned structure as a PDB-format string.
  """
  _check_atom_mask_is_ideal(prot)

  original_pdb = protein.to_pdb(prot)
  alterations = {}
  fixed = cleanup.fix_pdb(io.StringIO(original_pdb), alterations)
  structure = PdbStructure(io.StringIO(fixed))
  cleanup.clean_structure(structure, alterations)
  logging.info("alterations info: %s", alterations)

  cleaned_file = openmm_app.PDBFile(structure)
  cleaned_pdb = _get_pdb_string(cleaned_file.getTopology(),
                                cleaned_file.getPositions())
  if checks:
    _check_cleaned_atoms(cleaned_pdb, original_pdb)
  return cleaned_pdb


def make_atom14_positions(prot):
  """Constructs denser atom positions (14 dimensions instead of 37).

  Adds atom14-indexed views of the ground-truth coordinates to `prot`. It
  also adds the gather indices between the atom14 and atom37 layouts, and
  "alternative" ground-truth coordinates obtained by swapping the names of
  the ambiguous atom pairs listed in
  `residue_constants.residue_atom_renaming_swaps`.

  Args:
    prot: A mutable mapping with at least "aatype" (integer residue types
      indexing `residue_constants.restypes`, plus one trailing 'UNK' slot),
      "all_atom_positions" and "all_atom_mask" in the atom37 layout. The
      gathers below assume shapes (num_res,), (num_res, 37, 3) and
      (num_res, 37) respectively.

  Returns:
    The same `prot` object, updated in place with "atom14_atom_exists",
    "atom14_gt_exists", "atom14_gt_positions", "residx_atom14_to_atom37",
    "residx_atom37_to_atom14", "atom37_atom_exists",
    "atom14_alt_gt_positions", "atom14_alt_gt_exists" and
    "atom14_atom_is_ambiguous".
  """
  restype_atom14_to_atom37 = []  # mapping (restype, atom14) --> atom37
  restype_atom37_to_atom14 = []  # mapping (restype, atom37) --> atom14
  restype_atom14_mask = []

  for rt in residue_constants.restypes:
    atom_names = residue_constants.restype_name_to_atom14_names[
        residue_constants.restype_1to3[rt]]

    # Empty atom14 slots point at atom37 index 0; they are zeroed by the
    # atom14 mask further down.
    restype_atom14_to_atom37.append([
        (residue_constants.atom_order[name] if name else 0)
        for name in atom_names
    ])

    atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
    restype_atom37_to_atom14.append([
        atom_name_to_idx14.get(name, 0)
        for name in residue_constants.atom_types
    ])

    restype_atom14_mask.append([(1. if name else 0.) for name in atom_names])

  # Add dummy mapping for restype 'UNK'.
  restype_atom14_to_atom37.append([0] * 14)
  restype_atom37_to_atom14.append([0] * 37)
  restype_atom14_mask.append([0.] * 14)

  restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32)
  restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32)
  restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)

  # Create the mapping for (residx, atom14) --> atom37, i.e. an array
  # with shape (num_res, 14) containing the atom37 indices for this protein.
  residx_atom14_to_atom37 = restype_atom14_to_atom37[prot["aatype"]]
  residx_atom14_mask = restype_atom14_mask[prot["aatype"]]

  # Create a mask for known ground truth positions.
  residx_atom14_gt_mask = residx_atom14_mask * np.take_along_axis(
      prot["all_atom_mask"], residx_atom14_to_atom37, axis=1).astype(np.float32)

  # Gather the ground truth positions.
  residx_atom14_gt_positions = residx_atom14_gt_mask[:, :, None] * (
      np.take_along_axis(prot["all_atom_positions"],
                         residx_atom14_to_atom37[..., None],
                         axis=1))

  prot["atom14_atom_exists"] = residx_atom14_mask
  prot["atom14_gt_exists"] = residx_atom14_gt_mask
  prot["atom14_gt_positions"] = residx_atom14_gt_positions

  prot["residx_atom14_to_atom37"] = residx_atom14_to_atom37

  # Create the gather indices for mapping back.
  residx_atom37_to_atom14 = restype_atom37_to_atom14[prot["aatype"]]
  prot["residx_atom37_to_atom14"] = residx_atom37_to_atom14

  # Create the corresponding mask.
  restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
  for restype, restype_letter in enumerate(residue_constants.restypes):
    restype_name = residue_constants.restype_1to3[restype_letter]
    atom_names = residue_constants.residue_atoms[restype_name]
    for atom_name in atom_names:
      atom_type = residue_constants.atom_order[atom_name]
      restype_atom37_mask[restype, atom_type] = 1

  residx_atom37_mask = restype_atom37_mask[prot["aatype"]]
  prot["atom37_atom_exists"] = residx_atom37_mask

  # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
  # alternative ground truth coordinates where the naming is swapped
  restype_3 = [
      residue_constants.restype_1to3[res] for res in residue_constants.restypes
  ]
  restype_3 += ["UNK"]

  # Matrices for renaming ambiguous atoms (identity for unambiguous types).
  all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
  for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
    atom14_names = residue_constants.restype_name_to_atom14_names[resname]
    correspondences = np.arange(14)
    for source_atom_swap, target_atom_swap in swap.items():
      source_index = atom14_names.index(source_atom_swap)
      target_index = atom14_names.index(target_atom_swap)
      correspondences[source_index] = target_index
      correspondences[target_index] = source_index
    # Build the permutation matrix once, after every swap for this residue
    # has been recorded. Building it inside the swap loop would leave a stale
    # matrix from the previous residue whenever `swap` is empty.
    renaming_matrix = np.zeros((14, 14), dtype=np.float32)
    renaming_matrix[np.arange(14), correspondences] = 1.
    all_matrices[resname] = renaming_matrix
  renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3])

  # Pick the transformation matrices for the given residue sequence
  # shape (num_res, 14, 14).
  renaming_transform = renaming_matrices[prot["aatype"]]

  # Apply it to the ground truth positions. shape (num_res, 14, 3).
  alternative_gt_positions = np.einsum("rac,rab->rbc",
                                       residx_atom14_gt_positions,
                                       renaming_transform)
  prot["atom14_alt_gt_positions"] = alternative_gt_positions

  # Create the mask for the alternative ground truth (differs from the
  # ground truth mask, if only one of the atoms in an ambiguous pair has a
  # ground truth position).
  alternative_gt_mask = np.einsum("ra,rab->rb",
                                  residx_atom14_gt_mask,
                                  renaming_transform)

  prot["atom14_alt_gt_exists"] = alternative_gt_mask

  # Create an ambiguous atoms mask.  shape: (21, 14).
  restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32)
  for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
    # The residue-type row and atom14 name list are invariant across swaps.
    restype = residue_constants.restype_order[
        residue_constants.restype_3to1[resname]]
    atom14_names = residue_constants.restype_name_to_atom14_names[resname]
    for atom_name1, atom_name2 in swap.items():
      restype_atom14_is_ambiguous[restype, atom14_names.index(atom_name1)] = 1
      restype_atom14_is_ambiguous[restype, atom14_names.index(atom_name2)] = 1

  # From this create an ambiguous_mask for the given sequence.
  prot["atom14_atom_is_ambiguous"] = (
      restype_atom14_is_ambiguous[prot["aatype"]])

  return prot


def find_violations(prot_np: protein.Protein):
  """Runs AlphaFold's structural-violation checks on a protein.

  The protein's atom37 arrays are packed into a batch dict, expanded to the
  atom14 layout by `make_atom14_positions`, and the resulting ground-truth
  atom14 positions are scored as if they were model predictions.

  Args:
    prot_np: A protein.

  Returns:
    A `(violations, violation_metrics)` tuple: the structural violations
    from `folding.find_structural_violations` and the summary metrics from
    `folding.compute_violation_metrics`.
  """
  batch = make_atom14_positions({
      "aatype": prot_np.aatype,
      "all_atom_positions": prot_np.atom_positions.astype(np.float32),
      "all_atom_mask": prot_np.atom_mask.astype(np.float32),
      "residue_index": prot_np.residue_index,
      "seq_mask": np.ones_like(prot_np.aatype, np.float32),
  })
  positions = batch["atom14_gt_positions"]

  # Both tolerances are taken from the model config.
  config = ml_collections.ConfigDict({
      "violation_tolerance_factor": 12,
      "clash_overlap_tolerance": 1.5,
  })
  violations = folding.find_structural_violations(
      batch=batch, atom14_pred_positions=positions, config=config)
  violation_metrics = folding.compute_violation_metrics(
      batch=batch, atom14_pred_positions=positions, violations=violations)
  return violations, violation_metrics


def get_violation_metrics(prot: protein.Protein):
  """Returns the violation metrics of `prot` plus per-residue details.

  On top of the metrics from `find_violations`, the returned dict holds:
    residue_violations: indices of residues whose
      "total_per_residue_violations_mask" entry is non-zero.
    num_residue_violations: how many such residues there are.
    structural_violations: the raw violations dict.
  """
  violations, metrics = find_violations(prot)
  flagged = np.flatnonzero(violations["total_per_residue_violations_mask"])
  metrics.update(
      residue_violations=flagged,
      num_residue_violations=len(flagged),
      structural_violations=violations,
  )
  return metrics


def _run_one_iteration(
    *,
    pdb_string: str,
    max_iterations: int,
    tolerance: float,
    stiffness: float,
    restraint_set: str,
    max_attempts: int,
    exclude_residues: Optional[Collection[int]] = None):
  """Runs the minimization pipeline, retrying failed attempts.

  Args:
    pdb_string: A pdb string.
    max_iterations: An `int` specifying the maximum number of L-BFGS
      iterations. A value of 0 specifies no limit.
    tolerance: kcal/mol, the energy tolerance of L-BFGS.
    stiffness: kcal/mol A**2, spring constant of heavy atom restraining
      potential.
    restraint_set: The set of atoms to restrain.
    max_attempts: The maximum number of minimization attempts.
    exclude_residues: An optional list of zero-indexed residues to exclude from
      restraints.

  Returns:
    The dict produced by `_openmm_minimize`, extended with "opt_time" (wall
    time in seconds across all attempts) and "min_attempts" (the number of
    attempts made, including the successful one).

  Raises:
    ValueError: If no attempt succeeds within `max_attempts` attempts. The
      error from the last failed attempt, if any, is chained as the cause.
  """
  exclude_residues = exclude_residues or []

  # Assign physical dimensions.
  tolerance = tolerance * ENERGY
  stiffness = stiffness * ENERGY / (LENGTH**2)

  start = time.time()
  last_error = None
  for attempt in range(1, max_attempts + 1):
    logging.info("Minimizing protein, attempt %d of %d.",
                 attempt, max_attempts)
    try:
      ret = _openmm_minimize(
          pdb_string, max_iterations=max_iterations,
          tolerance=tolerance, stiffness=stiffness,
          restraint_set=restraint_set,
          exclude_residues=exclude_residues)
    except Exception as e:  # pylint: disable=broad-except
      # Failures are retried; keep the latest so it can be chained if every
      # attempt fails.
      logging.info(e)
      last_error = e
      continue
    ret["opt_time"] = time.time() - start
    ret["min_attempts"] = attempt
    return ret
  raise ValueError(
      f"Minimization failed after {max_attempts} attempts.") from last_error


def run_pipeline(
    prot: protein.Protein,
    stiffness: float,
    max_outer_iterations: int = 1,
    place_hydrogens_every_iteration: bool = True,
    max_iterations: int = 0,
    tolerance: float = 2.39,
    restraint_set: str = "non_hydrogen",
    max_attempts: int = 100,
    checks: bool = True,
    exclude_residues: Optional[Sequence[int]] = None):
  """Run iterative amber relax.

  Successive relax iterations are performed until all violations have been
  resolved or `max_outer_iterations` is reached. Each iteration involves a
  restrained Amber minimization. Residues found to participate in violations
  are added to the restraint exclusions for the following iterations.

  Args:
    prot: A protein to be relaxed.
    stiffness: kcal/mol A**2, the restraint stiffness.
    max_outer_iterations: The maximum number of iterative minimizations.
      Must be at least 1.
    place_hydrogens_every_iteration: Whether hydrogens are re-initialized
      prior to every minimization.
    max_iterations: An `int` specifying the maximum number of L-BFGS steps
      per relax iteration. A value of 0 specifies no limit.
    tolerance: kcal/mol, the energy tolerance of L-BFGS.
      The default value is the OpenMM default.
    restraint_set: The set of atoms to restrain.
    max_attempts: The maximum number of minimization attempts per iteration.
    checks: Whether to perform cleaning checks.
    exclude_residues: An optional list of zero-indexed residues to exclude from
      restraints.

  Returns:
    out: The result dict of the last iteration. It contains the minimization
      outputs of `_run_one_iteration` merged with `get_violation_metrics`,
      plus "num_exclusions" and "iteration".

  Raises:
    ValueError: If `max_outer_iterations` is smaller than 1, or if raised
      by the cleaning or minimization steps.
  """
  # With no iterations `ret` would never be assigned. Fail up front with a
  # clear message instead of cleaning the protein and then hitting an
  # UnboundLocalError at the return statement.
  if max_outer_iterations < 1:
    raise ValueError("max_outer_iterations must be at least 1, got "
                     f"{max_outer_iterations}.")

  # `protein.to_pdb` will strip any poorly-defined residues so we need to
  # perform this check before `clean_protein`.
  _check_residues_are_well_defined(prot)
  pdb_string = clean_protein(prot, checks=checks)

  exclude_residues = set(exclude_residues or [])
  violations = np.inf
  iteration = 0

  while violations > 0 and iteration < max_outer_iterations:
    ret = _run_one_iteration(
        pdb_string=pdb_string,
        exclude_residues=exclude_residues,
        max_iterations=max_iterations,
        tolerance=tolerance,
        stiffness=stiffness,
        restraint_set=restraint_set,
        max_attempts=max_attempts)
    prot = protein.from_pdb_string(ret["min_pdb"])
    if place_hydrogens_every_iteration:
      pdb_string = clean_protein(prot, checks=True)
    else:
      pdb_string = ret["min_pdb"]
    ret.update(get_violation_metrics(prot))
    ret.update({
        "num_exclusions": len(exclude_residues),
        "iteration": iteration,
    })
    violations = ret["violations_per_residue"]
    # Residues still in violation are left unrestrained next time round.
    exclude_residues = exclude_residues.union(ret["residue_violations"])

    logging.info("Iteration completed: Einit %.2f Efinal %.2f Time %.2f s "
                 "num residue violations %d num residue exclusions %d ",
                 ret["einit"], ret["efinal"], ret["opt_time"],
                 ret["num_residue_violations"], ret["num_exclusions"])
    iteration += 1
  return ret


def get_initial_energies(pdb_strs: Sequence[str],
                         stiffness: float = 0.0,
                         restraint_set: str = "non_hydrogen",
                         exclude_residues: Optional[Sequence[int]] = None):
  """Returns initial potential energies for a sequence of PDBs.

  Assumes the input PDBs are ready for minimization, and all have the same
  topology.
  Allows time to be saved by not pdbfixing / rebuilding the system: a single
  system and simulation are built from the first PDB and reused for all.

  Args:
    pdb_strs: List of PDB strings.
    stiffness: kcal/mol A**2, spring constant of heavy atom restraining
      potential.
    restraint_set: Which atom types to restrain.
    exclude_residues: An optional list of zero-indexed residues to exclude from
      restraints.

  Returns:
    A list of initial energies as plain floats in kcal/mol, in the same order
    as pdb_strs. A structure whose energy cannot be evaluated gets 1e20.
    Returns an empty list when `pdb_strs` is empty.
  """
  exclude_residues = exclude_residues or []
  # The system is built from the first PDB, so there is nothing to evaluate
  # (and no topology to build from) when the input is empty.
  if not pdb_strs:
    return []

  openmm_pdbs = [openmm_app.PDBFile(PdbStructure(io.StringIO(p)))
                 for p in pdb_strs]
  force_field = openmm_app.ForceField("amber99sb.xml")
  system = force_field.createSystem(openmm_pdbs[0].topology,
                                    constraints=openmm_app.HBonds)
  stiffness = stiffness * ENERGY / (LENGTH**2)
  if stiffness > 0 * ENERGY / (LENGTH**2):
    _add_restraints(system, openmm_pdbs[0], stiffness, restraint_set,
                    exclude_residues)
  simulation = openmm_app.Simulation(openmm_pdbs[0].topology,
                                     system,
                                     openmm.LangevinIntegrator(0, 0.01, 0.0),
                                     openmm.Platform.getPlatformByName("CPU"))
  energies = []
  for pdb in openmm_pdbs:
    try:
      simulation.context.setPositions(pdb.positions)
      state = simulation.context.getState(getEnergy=True)
      energies.append(state.getPotentialEnergy().value_in_unit(ENERGY))
    except Exception as e:  # pylint: disable=broad-except
      logging.error("Error getting initial energy, returning large value %s", e)
      # Use a plain float (kcal/mol) like the successful entries. A
      # unit.Quantity here would make the list mixed-type and break
      # comparisons or sorting of the returned energies.
      energies.append(1e20)
  return energies
